{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import os\n",
    "import torch\n",
    "import torch.optim\n",
    "import torch.nn as nn\n",
    "from torch.utils.data import Dataset,DataLoader\n",
    "from sklearn.preprocessing import OneHotEncoder\n",
    "from torch.autograd import Variable\n",
    "import cv2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[0. 0. 0. ... 0. 0. 0.]\n",
      " [0. 0. 0. ... 0. 0. 0.]\n",
      " [0. 0. 0. ... 1. 0. 0.]\n",
      " ...\n",
      " [0. 0. 0. ... 0. 0. 0.]\n",
      " [0. 0. 0. ... 0. 0. 0.]\n",
      " [0. 0. 0. ... 1. 0. 0.]]\n",
      "(44440, 143)\n"
     ]
    }
   ],
   "source": [
    "# Load the metadata table; every row describes one product image.\n",
    "data = pd.read_csv('styles.csv')\n",
    "\n",
    "# Image ids, used later to build the image file names.\n",
    "ID = data['id'].values\n",
    "\n",
    "# Article-type names copied into a fresh array and shaped as a single column.\n",
    "label = np.array(list(data['articleType'].values)).reshape(-1, 1)\n",
    "\n",
    "# One-hot encode the article types and densify the sparse result.\n",
    "one_hot = OneHotEncoder()\n",
    "label = one_hot.fit_transform(label).toarray()\n",
    "print(label[:10])\n",
    "print(label.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([[[255., 255., 255.,  ..., 255., 255., 254.],\n",
       "          [255., 255., 254.,  ..., 254., 255., 254.],\n",
       "          [254., 255., 255.,  ..., 255., 255., 255.],\n",
       "          ...,\n",
       "          [122., 118., 160.,  ..., 255., 255., 255.],\n",
       "          [255., 255., 255.,  ..., 106., 143., 134.],\n",
       "          [133., 180., 169.,  ...,  94.,  88., 132.]],\n",
       " \n",
       "         [[120., 116., 152.,  ..., 255., 255., 255.],\n",
       "          [255., 255., 255.,  ...,  75., 106.,  96.],\n",
       "          [ 95., 136., 125.,  ...,  72.,  66., 123.],\n",
       "          ...,\n",
       "          [135., 146., 131.,  ..., 127., 125., 133.],\n",
       "          [120., 119., 129.,  ..., 251., 254., 255.],\n",
       "          [255., 255., 247.,  ..., 127., 127., 114.]],\n",
       " \n",
       "         [[120., 123., 108.,  ..., 138., 139., 151.],\n",
       "          [138., 139., 152.,  ..., 251., 254., 255.],\n",
       "          [255., 255., 246.,  ..., 134., 128., 115.],\n",
       "          ...,\n",
       "          [248., 255., 255.,  ...,  72., 111.,  97.],\n",
       "          [ 76., 117., 103.,  ...,  74.,  45.,  68.],\n",
       "          [ 61.,  38.,  68.,  ..., 253., 255., 254.]]]), 104)"
      ]
     },
     "execution_count": 36,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Custom Dataset that reads product images from disk\n",
    "class DataSet(Dataset):\n",
    "    \"\"\"Image dataset built from the global ``ID`` and ``label`` arrays.\n",
    "\n",
    "    Each item is ``(img, y_true)``: ``img`` is a float32 tensor of shape\n",
    "    (3, 128, 128) in OpenCV's BGR channel order, and ``y_true`` is the\n",
    "    integer class index taken from the one-hot row ``label[index]``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        # Build the paths with os.path.join so they are not tied to Windows separators.\n",
    "        self.path = [os.path.join(\".\", \"images\", str(i) + \".jpg\") for i in ID]\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        \"\"\"Return the image tensor and class index for sample ``index``.\n",
    "\n",
    "        Raises:\n",
    "            FileNotFoundError: if the image file is missing or cannot be decoded.\n",
    "        \"\"\"\n",
    "        path = self.path[index]\n",
    "        img = cv2.imread(path)\n",
    "        if img is None:\n",
    "            # cv2.imread reports failure by returning None instead of raising.\n",
    "            raise FileNotFoundError(\"cannot read image: \" + path)\n",
    "        img = cv2.resize(img, (128, 128))\n",
    "        # OpenCV yields (H, W, C) while Conv2d expects (C, H, W). The axes must be\n",
    "        # permuted; a plain reshape would scramble pixels across channels.\n",
    "        img = torch.from_numpy(img).permute(2, 0, 1).contiguous().type(torch.float32)\n",
    "        # CrossEntropyLoss takes integer class indices (int64), not one-hot rows.\n",
    "        y_true = np.argmax(label[index], -1).astype(np.longlong)\n",
    "\n",
    "        return img, y_true\n",
    "\n",
    "    def __len__(self):\n",
    "        \"\"\"Return the number of samples (one per entry of ``ID``).\"\"\"\n",
    "        return len(self.path)\n",
    "\n",
    "\n",
    "dataset = DataSet()\n",
    "data_loader = DataLoader(dataset, batch_size=128, shuffle=True, drop_last=True)\n",
    "# Show only the shape and label of the first sample rather than the whole tensor.\n",
    "sample_img, sample_label = dataset[0]\n",
    "sample_img.shape, sample_label"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Network definition\n",
    "class Net_work(nn.Module):\n",
    "    \"\"\"CNN classifier for 3x128x128 images with 143 output classes.\n",
    "\n",
    "    Three conv -> max-pool -> batch-norm stages halve the spatial size each\n",
    "    time (128 -> 64 -> 32 -> 16), followed by a fully connected head.\n",
    "    ``forward`` returns raw logits: ``nn.CrossEntropyLoss`` applies log-softmax\n",
    "    itself, so a Softmax layer here would normalise twice and squash the\n",
    "    gradients. Use ``torch.softmax(logits, dim=-1)`` when probabilities are needed.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super(Net_work, self).__init__()\n",
    "        self.layer1 = nn.Sequential(\n",
    "            nn.Conv2d(3,16,kernel_size=3,stride=1,padding=1),\n",
    "            nn.MaxPool2d(stride=2,kernel_size=2),     # 64 64\n",
    "            nn.BatchNorm2d(16),\n",
    "        )\n",
    "\n",
    "        self.layer2 = nn.Sequential(\n",
    "            nn.Conv2d(16,64,kernel_size=3,stride=1,padding=1),\n",
    "            nn.MaxPool2d(stride=2,kernel_size=2),  # 32 32\n",
    "            nn.BatchNorm2d(64),\n",
    "        )\n",
    "\n",
    "        self.layer3 = nn.Sequential(\n",
    "            nn.Conv2d(64,128,kernel_size=3,stride=1,padding=1),\n",
    "            nn.MaxPool2d(stride=2,kernel_size=2),   #16 16\n",
    "            nn.BatchNorm2d(128)\n",
    "        )\n",
    "\n",
    "        # Classification head; the final Linear emits one logit per class.\n",
    "        self.fc1 = nn.Sequential(\n",
    "            nn.Linear(128*16*16,2048),\n",
    "            nn.Dropout(0.5),\n",
    "            nn.LeakyReLU(0.2),\n",
    "            nn.Linear(2048,256),\n",
    "            nn.Dropout(0.5),\n",
    "            nn.LeakyReLU(0.2),\n",
    "            nn.Linear(256,143),\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Map a (batch, 3, 128, 128) float tensor to (batch, 143) logits.\"\"\"\n",
    "        x = self.layer1(x)\n",
    "        x = self.layer2(x)\n",
    "        x = self.layer3(x)\n",
    "\n",
    "        # Flatten the feature maps to (batch, 128*16*16) for the linear head.\n",
    "        x = x.view(x.size(0),-1)\n",
    "\n",
    "        x = self.fc1(x)\n",
    "\n",
    "        return x\n",
    "\n",
    "net_work = Net_work()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Net_work(\n",
       "  (layer1): Sequential(\n",
       "    (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "    (1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
       "    (2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "  )\n",
       "  (layer2): Sequential(\n",
       "    (0): Conv2d(16, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "    (1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
       "    (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "  )\n",
       "  (layer3): Sequential(\n",
       "    (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "    (1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
       "    (2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "  )\n",
       "  (fc1): Sequential(\n",
       "    (0): Linear(in_features=32768, out_features=2048, bias=True)\n",
       "    (1): Dropout(p=0.5)\n",
       "    (2): LeakyReLU(negative_slope=0.2)\n",
       "    (3): Linear(in_features=2048, out_features=256, bias=True)\n",
       "    (4): Dropout(p=0.5)\n",
       "    (5): LeakyReLU(negative_slope=0.2)\n",
       "    (6): Linear(in_features=256, out_features=143, bias=True)\n",
       "    (7): Softmax()\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 40,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Loss function and optimizer\n",
    "error = nn.CrossEntropyLoss()\n",
    "optimizer = torch.optim.Adam(params=net_work.parameters(), lr=0.005)\n",
    "net_work.train()    # switch the model to training mode"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(4.8498, grad_fn=<NllLossBackward>)\n",
      "tensor(4.7852, grad_fn=<NllLossBackward>)\n"
     ]
    }
   ],
   "source": [
    "# Training loop\n",
    "for epoch in range(2):\n",
    "    for x, y in data_loader:\n",
    "        # Batches are used directly: tensors track gradients themselves, so the\n",
    "        # deprecated Variable wrapper is not needed.\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        pred = net_work(x)\n",
    "        loss = error(pred, y)\n",
    "\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        # .item() prints the plain float instead of the tensor repr with grad_fn.\n",
    "        print(loss.item())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.save(net_work.state_dict(),\"a.pth\")   # persist the trained weights of net_work after training"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
